--- siren/siren/siren.py	2024-02-13 11:03:14.231374704 -0500
+++ external/google/siren/siren/siren.py	2024-02-13 10:46:24.484645120 -0500
@@ -38,7 +38,8 @@
                  w0_initial: float = 30.0,
                  bias: bool = True,
                  initializer: str = 'siren',
-                 c: float = 6):
+                 c: float = 6,
+                 use_sine_activation: bool = True):
         """
         SIREN model from the paper [Implicit Neural Representations with
         Periodic Activation Functions](https://arxiv.org/abs/2006.09661).
@@ -64,6 +65,8 @@
         :param c: value used to compute the bound in the siren intializer.
             defaults to 6
         :type c: float, optional
+        :param use_sine_activation: if True, use Sine activations; if False, use nn.ReLU instead (a standard ReLU MLP) and do not apply the 'siren' initializer.
+            defaults to True
 
         # References:
             -   [Implicit Neural Representations with Periodic Activation
@@ -71,19 +74,20 @@
         """
         super(SIREN, self).__init__()
         self._check_params(layers)
-        self.layers = [nn.Linear(in_features, layers[0], bias=bias), Sine(
-            w0=w0_initial)]
+        self.layers = [nn.Linear(in_features, layers[0], bias=bias), 
+                       Sine(w0=w0_initial) if use_sine_activation else nn.ReLU()]
+        self.use_sine_activation = use_sine_activation
 
         for index in range(len(layers) - 1):
             self.layers.extend([
                 nn.Linear(layers[index], layers[index + 1], bias=bias),
-                Sine(w0=w0)
+                Sine(w0=w0) if use_sine_activation else nn.ReLU()
             ])
 
         self.layers.append(nn.Linear(layers[-1], out_features, bias=bias))
         self.network = nn.Sequential(*self.layers)
 
-        if initializer is not None and initializer == 'siren':
+        if use_sine_activation and initializer is not None and initializer == 'siren':
             for m in self.network.modules():
                 if isinstance(m, nn.Linear):
                     siren_uniform_(m.weight, mode='fan_in', c=c)
